Image inpainting is an important problem in computer vision. It is challenging because recognizing the contents of an image is not enough. To fill in a missing region correctly, a model must also understand the surrounding context of the image.
In this project, we implement a version of image inpainting inspired by the paper Context Encoders by Deepak Pathak et al. We build an autoencoder that encodes images into a latent space capturing the "context" of the image, and then decodes that representation to generate the inpainted region. We ran into the same problem the paper's authors reported: the L2 loss function leads to blurry inpainting results. The paper addresses this by training an adversarial network to provide an adversarial loss. We instead investigate alternative ways to reduce the blurring. We experiment with loss functions that pay more attention to image structure, such as SSIM, and with a custom loss that uses image gradients to detect edges and penalizes differences between the edges of the prediction and those of the ground truth.
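Why does a pixel-wise L2 loss blur? A short toy illustration helps (a sketch with made-up 1-D data, not taken from our experiments). Suppose the same context admits two equally plausible completions, with a sharp edge in one of two positions. Under L2, the best single prediction is the per-pixel average of the candidates, and that average is a soft, blurred edge.

```python
import numpy as np

# Two equally likely ground-truth completions of one row of pixels (hypothetical toy data):
# a sharp edge starting at position 3, or a sharp edge starting at position 5.
fill_a = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=float)
fill_b = np.array([0, 0, 0, 0, 0, 1, 1, 1], dtype=float)
targets = np.stack([fill_a, fill_b])

def expected_l2(pred):
    # Mean squared error averaged over both possible targets.
    return np.mean((targets - pred) ** 2)

# The prediction that minimizes the expected L2 loss is the per-pixel mean of the targets,
# which has intermediate 0.5 values where the two edges disagree (a blurred edge).
l2_optimal = targets.mean(axis=0)

# Committing to either sharp edge gives a higher expected L2 loss than the blurry average.
print(l2_optimal, expected_l2(l2_optimal), expected_l2(fill_a))
```

An L1 loss does not fully fix this either. With two targets, every value between them minimizes the expected absolute error, so L1 alone does not force the model to commit to a sharp edge.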
Furthermore, we investigate how the latent space of our context encoder differs from the latent space of ResNet. To do this, we project the high-dimensional embeddings onto two dimensions with principal component analysis (PCA). We find that the two latent spaces are meaningfully different, which explains why we initially could not generate inpaintings using only the ResNet encoder.
// TODO: talk about black hole in the center of the image
PyTorch - Used for autoencoder training
TorchVision - Used for image transformations and data loading
Matplotlib - Used for visualizations
Pytorch_SSIM - Used to implement the SSIM loss function in one of our experiments
scikit-learn (sklearn) - Used for principal component analysis
tqdm - Used to display progress bars during training
import torch
import torchvision.transforms as T
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn.functional as F
import torch.nn as nn
from sklearn.decomposition import PCA
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm
from pytorch_ssim import pytorch_ssim
# We use the ImageNet dataset.
# You can download it with: kaggle competitions download -c imagenet-object-localization-challenge
# Update the following variable with your dataset location.
DATASET_PATH = '/pub2/imagenet/ILSVRC/Data/CLS-LOC/train'
# Size of the images fed into the model. We downscale images so we can train smaller models more quickly;
# we expect the approach to extend to larger images as well.
IMAGE_SIZE = 128
CENTER_SIZE = 64  # Side length (in pixels) of the square region removed from the center of each image
BATCH_SIZE = 500
# The ImageNet dataset has 1,000 classes. Training a model that works well on all of them would take more time
# and computational resources than we had available for this project, so we restrict the model to a few fruit
# (and corn) classes. These classes are still different enough that the model must learn to inpaint meaningfully
# different types of objects, and the small number of categories keeps training time manageable.
SELECTED_CLASSES = {
# 858: "Tile Roof", # Tile Roof
# 937: "Broccoli",
# 980: "Volcano",
987: "Corn",
949: 'Strawberry',
948: "Apple",
954: "Banana",
953: "Pineapple",
951: "Lemon",
950: 'Orange',
967: 'Pomegranate'
}
The following are some utility methods that will be helpful later in our code.
import math

def display_image(img):
    """Display a single CHW image tensor."""
    with torch.no_grad():
        img = img.permute((1, 2, 0))
        plt.imshow(img.cpu().numpy())
        plt.show()

def display_images(imgs, images_per_row=2, title=None):
    """Display a list of CHW image tensors in a grid."""
    with torch.no_grad():
        rows = math.ceil(len(imgs) / images_per_row)
        fig = plt.figure(figsize=(images_per_row * 2, rows * 2))
        if title:
            fig.suptitle(title, fontsize=10)
        for i, img in enumerate(imgs):
            plt.subplot(rows, images_per_row, i + 1)
            plt.axis('off')
            img = img.permute((1, 2, 0))
            plt.imshow(img.cpu().numpy())
        plt.tight_layout()
        plt.show()

def get_center_of_image(img, size=CENTER_SIZE):
    """Crop the central size x size region (the region the decoder predicts)."""
    return T.CenterCrop((size, size))(img)

def get_img_without_center(img):
    """Black out the central CENTER_SIZE x CENTER_SIZE region in place (CHW or NCHW)."""
    left = (IMAGE_SIZE - CENTER_SIZE) // 2
    right = left + CENTER_SIZE
    if len(img.shape) == 3:
        img[:, left:right, left:right] = 0
    elif len(img.shape) == 4:
        img[:, :, left:right, left:right] = 0
    else:
        raise RuntimeError("Wrong shape to get_img_without_center!")
    return img
class CenterImageRemoval(object):
    """Transform that blacks out the center of an image tensor (modifies it in place)."""
    def __init__(self, image_size, center_size):
        self.image_size = image_size
        self.center_size = center_size

    def __call__(self, img):
        left = (self.image_size - self.center_size) // 2
        right = left + self.center_size
        if len(img.shape) == 3:
            assert img.shape[1] == img.shape[2] == self.image_size
            img[:, left:right, left:right] = 0
        elif len(img.shape) == 4:
            assert img.shape[2] == img.shape[3] == self.image_size
            img[:, :, left:right, left:right] = 0
        return img

class AddImageCenter(object):
    """Paste a (predicted) center patch back into an image tensor (modifies it in place)."""
    def __init__(self, image_size, center_size):
        self.image_size = image_size
        self.center_size = center_size

    def __call__(self, img, center):
        # Check the last two (spatial) dimensions so both CHW and NCHW centers work.
        assert center.shape[-2] == center.shape[-1] == self.center_size
        left = (self.image_size - self.center_size) // 2
        right = left + self.center_size
        if len(img.shape) == 3:
            assert img.shape[1] == img.shape[2] == self.image_size
            img[:, left:right, left:right] = center
        elif len(img.shape) == 4:
            assert img.shape[2] == img.shape[3] == self.image_size
            img[:, :, left:right, left:right] = center
        return img
import random

class JointRandomResizeCrop(object):
    def __init__(self, size: int, minimum_scale, maximum_scale):
        """
        params:
            size (int): side length of the square output image and target
            minimum_scale (float): smallest random rescaling factor
            maximum_scale (float): largest random rescaling factor
        """
        self.size = size
        self.min_scale = minimum_scale
        self.max_scale = maximum_scale

    def __call__(self, img, target):
        # Rescale image and target by the same random factor, then take the same random crop from both.
        scale = (self.max_scale - self.min_scale) * random.random() + self.min_scale
        _, height, width = img.shape
        new_h, new_w = int(height * scale), int(width * scale)
        target = target.reshape((1, height, width))
        resized_img = T.functional.resize(img, [new_h, new_w])
        resized_target = T.functional.resize(target, [new_h, new_w])
        crop_size = min(self.size, new_h, new_w)
        top = random.randint(0, new_h - crop_size)
        left = random.randint(0, new_w - crop_size)
        resized_img = T.functional.crop(resized_img, top, left, crop_size, crop_size)
        resized_target = T.functional.crop(resized_target, top, left, crop_size, crop_size)
        resized_img = T.functional.resize(resized_img, size=(self.size, self.size))
        resized_target = T.functional.resize(resized_target, size=(self.size, self.size))
        resized_target = resized_target.reshape((self.size, self.size))
        return (resized_img, resized_target)
def save_model(encoder, decoder, epoch, optimizer, path):
    state = {
        'epoch': epoch,
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)

def load_model(encoder, decoder, optimizer, path):
    # map_location lets a checkpoint saved on a GPU be loaded on the current device.
    state = torch.load(path, map_location=device)
    encoder.load_state_dict(state["encoder"])
    decoder.load_state_dict(state["decoder"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
We split the data into three datasets: a training set, a validation set, and a one-image "sanity" set drawn from the training data.
We create a custom PyTorch Dataset class (CustomDataSet) that applies the extra transformation we want on the model input (blacking out the center) but not on the ground truth.
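The relationship between the model input, the ground truth, and the region the decoder must predict can be sketched with plain NumPy arrays. This is a simplified stand-in for `CenterImageRemoval`, `get_center_of_image`, and `AddImageCenter`; the helper names `make_training_pair` and `paste_center` are hypothetical. The copy step matters because the masking is done in place, which is why the dataset clones the ground truth before masking.

```python
import numpy as np

IMAGE_SIZE = 128  # same values as the notebook's constants
CENTER_SIZE = 64

def center_bounds(image_size=IMAGE_SIZE, center_size=CENTER_SIZE):
    # Start (inclusive) and end (exclusive) index of the square center region.
    left = (image_size - center_size) // 2
    return left, left + center_size

def make_training_pair(gt):
    """Return (masked_input, ground_truth, center_target) for a CHW image array."""
    left, right = center_bounds()
    masked = gt.copy()                                    # copy so the ground truth stays intact
    masked[:, left:right, left:right] = 0                 # black out the center (model input)
    center_target = gt[:, left:right, left:right].copy()  # what the decoder must predict
    return masked, gt, center_target

def paste_center(masked, center):
    # Inverse step: insert a (predicted) center patch into a copy of the masked image.
    left, right = center_bounds()
    out = masked.copy()
    out[:, left:right, left:right] = center
    return out

rng = np.random.default_rng(0)
gt = rng.random((3, IMAGE_SIZE, IMAGE_SIZE)).astype(np.float32)
masked, gt_out, center = make_training_pair(gt)
```

Pasting the true center back into the masked input reconstructs the ground truth exactly. The loss only needs to compare the decoder's 64 x 64 output with `center_target`.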
train_tranforms = T.Compose([
T.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
T.RandomResizedCrop(size=(IMAGE_SIZE, IMAGE_SIZE)),
T.ToTensor(),
T.RandomHorizontalFlip(p=0.5),
T.ConvertImageDtype(torch.float32),
])
val_tranforms = T.Compose([
T.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
])
# extra transformation not applied to ground truth images (removing the center of the image)
non_gt_extra_transforms = T.Compose([
CenterImageRemoval(IMAGE_SIZE, CENTER_SIZE),
])
sanity_transforms = T.Compose([
T.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
T.CenterCrop(size=(IMAGE_SIZE, IMAGE_SIZE)),
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
])
class CustomDataSet(Dataset):
    def __init__(self, root_dir, classes, transform=None, non_gt_extra_transforms=None, training_split=0.8, is_train=True):
        self.root_dir = root_dir
        # Fall back to a plain tensor conversion so __getitem__ always has a tensor to clone.
        self.transform = transform if transform is not None else T.ToTensor()
        self.non_gt_extra_transforms = non_gt_extra_transforms
        selected_class_indexes = list(classes.keys())
        # Class folders are sorted, so a class index selects the matching folder.
        dirs = np.array(sorted(os.listdir(root_dir)))[selected_class_indexes]
        self.images = []
        self.cache = {}  # decoded PIL images keyed by path (grows without bound)
        self.is_train = is_train
        for i, dir in enumerate(dirs):
            imgs = sorted(os.listdir(os.path.join(self.root_dir, dir)))
            train_split = int(len(imgs) * training_split)
            if is_train:
                imgs = imgs[0:train_split]
            else:
                imgs = imgs[train_split:]
            for img in imgs:
                self.images.append((os.path.join(self.root_dir, dir, img), selected_class_indexes[i]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name, class_label = self.images[idx]
        if img_name not in self.cache:
            self.cache[img_name] = Image.open(img_name).convert('RGB')
        image = self.cache[img_name]
        transformed_image = self.transform(image)
        # Keep an untouched copy as the ground truth before masking the input in place.
        gt_image = torch.clone(transformed_image)
        if self.non_gt_extra_transforms:
            transformed_image = self.non_gt_extra_transforms(transformed_image)
        return transformed_image, gt_image, class_label
# Due to the way the dataset is structured, we can't use the official validation images.
# Instead, we split the provided training images of each class: 90% for training and 10% for validation.
train_dataset = CustomDataSet(root_dir=DATASET_PATH,
transform=train_tranforms,
non_gt_extra_transforms=non_gt_extra_transforms,
classes=SELECTED_CLASSES,
is_train=True,
training_split=0.9
)
validation_dataset = CustomDataSet(root_dir=DATASET_PATH,
    transform=val_tranforms,
    non_gt_extra_transforms=non_gt_extra_transforms,
    classes=SELECTED_CLASSES,
    is_train=False,
    training_split=0.9  # must match the training dataset, otherwise the two splits overlap
)
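The training and validation datasets are built by slicing the same sorted file list. For that reason, both must use the same `training_split`. Otherwise the validation slice overlaps images the model was trained on. The pure-Python sketch below mirrors the slicing in `CustomDataSet`; the function `split_files` and the file names are hypothetical.

```python
def split_files(files, training_split=0.9, is_train=True):
    """Deterministic per-class split, mirroring CustomDataSet's slicing of a sorted list."""
    files = sorted(files)
    cut = int(len(files) * training_split)
    return files[:cut] if is_train else files[cut:]

files = [f"img_{i:03d}.JPEG" for i in range(100)]  # hypothetical file names for one class
train = split_files(files, 0.9, is_train=True)
val_ok = split_files(files, 0.9, is_train=False)   # same fraction: disjoint from train
val_bad = split_files(files, 0.8, is_train=False)  # mismatched fraction: overlaps train
```

With matching fractions the two lists are disjoint. With 0.9 for training and 0.8 for validation, 10 of the 20 "validation" files are also training files.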
# Like the training dataset (but with deterministic transforms), reduced to a single image for sanity checks
sanity_dataset = CustomDataSet(root_dir=DATASET_PATH,
transform=sanity_transforms,
non_gt_extra_transforms=non_gt_extra_transforms,
classes=SELECTED_CLASSES,
is_train=True,
training_split=0.9
)
sanity_dataset = torch.utils.data.Subset(sanity_dataset, [1])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True)
sanity_loader = DataLoader(sanity_dataset, batch_size=1, shuffle=True)
sample_batch = next(iter(train_loader))
sample_cutout, sample_gt, sample_class_idx = sample_batch
def display_sample_batch():
    # Show (masked input, ground truth) pairs for the first 20 images of the batch.
    images = []
    for i in range(20):
        images += [sample_cutout[i], sample_gt[i]]
    display_images(images, images_per_row=6)
display_sample_batch()
LATENT_SPACE_DIM = 4000
HIDDEN_LAYER_SIZE = 64

class Encoder(nn.Module):
    """Encodes a masked 3 x 128 x 128 image into a LATENT_SPACE_DIM-dimensional vector.

    An earlier version of this encoder used a pretrained ResNet-50 backbone instead.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        # Each conv (kernel 4, stride 2, padding 1) halves the spatial size.
        # (128, 128) -> (64, 64)
        self.conv1 = nn.Conv2d(3, HIDDEN_LAYER_SIZE, kernel_size=4, padding=1, stride=2)
        # (64, 64) -> (32, 32)
        self.conv2 = nn.Conv2d(HIDDEN_LAYER_SIZE, HIDDEN_LAYER_SIZE, kernel_size=4, padding=1, stride=2)
        # (32, 32) -> (16, 16)
        self.conv3 = nn.Conv2d(HIDDEN_LAYER_SIZE, 2 * HIDDEN_LAYER_SIZE, kernel_size=4, padding=1, stride=2)
        # (16, 16) -> (8, 8)
        self.conv4 = nn.Conv2d(2 * HIDDEN_LAYER_SIZE, 4 * HIDDEN_LAYER_SIZE, kernel_size=4, padding=1, stride=2)
        # (8, 8) -> (4, 4)
        self.conv5 = nn.Conv2d(4 * HIDDEN_LAYER_SIZE, 8 * HIDDEN_LAYER_SIZE, kernel_size=4, padding=1, stride=2)
        self.flatten = nn.Flatten()
        # (8 * HIDDEN_LAYER_SIZE, 4, 4) -> (LATENT_SPACE_DIM,)
        self.linear = nn.Linear(4 * 4 * (8 * HIDDEN_LAYER_SIZE), LATENT_SPACE_DIM)

    def forward(self, inp):
        inp = F.leaky_relu(self.conv1(inp))
        inp = F.leaky_relu(self.conv2(inp))
        inp = F.leaky_relu(self.conv3(inp))
        inp = F.leaky_relu(self.conv4(inp))
        inp = F.leaky_relu(self.conv5(inp))
        inp = self.flatten(inp)
        inp = F.leaky_relu(self.linear(inp))
        return inp
class Decoder(nn.Module):
    """Decodes a LATENT_SPACE_DIM vector into the predicted 3 x 64 x 64 center patch."""
    def __init__(self):
        super(Decoder, self).__init__()
        # (LATENT_SPACE_DIM,) -> (8 * HIDDEN_LAYER_SIZE, 4, 4) after reshaping in forward()
        self.linear = nn.Linear(LATENT_SPACE_DIM, 4 * 4 * 8 * HIDDEN_LAYER_SIZE)
        # Each transposed conv (kernel 3, stride 2, padding 1, output_padding 1) doubles the spatial size.
        # (4, 4) -> (8, 8)
        self.conv_t1 = nn.ConvTranspose2d(8 * HIDDEN_LAYER_SIZE, 4 * HIDDEN_LAYER_SIZE, kernel_size=(3, 3), stride=(2, 2), padding=1, output_padding=1)
        self.batch_norm_1 = nn.BatchNorm2d(4 * HIDDEN_LAYER_SIZE)
        # (8, 8) -> (16, 16)
        self.conv_t2 = nn.ConvTranspose2d(4 * HIDDEN_LAYER_SIZE, 2 * HIDDEN_LAYER_SIZE, kernel_size=(3, 3), stride=(2, 2), padding=1, output_padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(2 * HIDDEN_LAYER_SIZE)
        # (16, 16) -> (32, 32)
        self.conv_t3 = nn.ConvTranspose2d(2 * HIDDEN_LAYER_SIZE, HIDDEN_LAYER_SIZE, kernel_size=(3, 3), stride=(2, 2), padding=1, output_padding=1)
        self.batch_norm_3 = nn.BatchNorm2d(HIDDEN_LAYER_SIZE)
        # (32, 32) -> (64, 64), with 3 output channels (RGB)
        self.conv_t4 = nn.ConvTranspose2d(HIDDEN_LAYER_SIZE, 3, kernel_size=(3, 3), stride=(2, 2), padding=1, output_padding=1)

    def forward(self, inp):
        inp = self.linear(inp)
        # Reshape to (batch, channels, 4, 4); inp.shape[0] is the batch size.
        inp = inp.reshape(inp.shape[0], -1, 4, 4)
        inp = F.relu(self.batch_norm_1(self.conv_t1(inp)))
        inp = F.relu(self.batch_norm_2(self.conv_t2(inp)))
        inp = F.relu(self.batch_norm_3(self.conv_t3(inp)))
        inp = self.conv_t4(inp)
        # Sigmoid keeps predicted pixel values in [0, 1].
        inp = torch.sigmoid(inp)
        return inp
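The shape comments in `Encoder` and `Decoder` can be checked with the output-size formulas from the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1). The pure-Python sketch below traces both stacks. It confirms that five stride-2 convolutions take 128 down to 4, and that four transposed convolutions take 4 up to 64, which is exactly `CENTER_SIZE`.

```python
def conv2d_out(size, kernel, stride, padding):
    # Spatial output size of nn.Conv2d with dilation 1.
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride, padding, output_padding):
    # Spatial output size of nn.ConvTranspose2d with dilation 1.
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: five convolutions with kernel 4, stride 2, padding 1.
enc_sizes = [128]
for _ in range(5):
    enc_sizes.append(conv2d_out(enc_sizes[-1], kernel=4, stride=2, padding=1))

# Decoder: four transposed convolutions with kernel 3, stride 2, padding 1, output_padding 1.
dec_sizes = [4]
for _ in range(4):
    dec_sizes.append(conv_transpose2d_out(dec_sizes[-1], kernel=3, stride=2, padding=1, output_padding=1))

HIDDEN_LAYER_SIZE = 64
# Input features of the encoder's final nn.Linear: 4 x 4 spatial positions x 512 channels.
flattened = enc_sizes[-1] ** 2 * 8 * HIDDEN_LAYER_SIZE
```

The flattened feature count (8192) is what the encoder's `nn.Linear` compresses into the 4000-dimensional latent vector.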
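def predict_fill_area_for_img(img, encoder, decoder):
    """Predict the CENTER_SIZE x CENTER_SIZE center patch for a single masked CHW image."""
    # Inference: move the models to the device and use BatchNorm running statistics.
    encoder = encoder.to(device).eval()
    decoder = decoder.to(device).eval()
    with torch.no_grad():
        img = img.to(device).unsqueeze(0)  # add a batch dimension
        embedding = encoder(img)
        out = decoder(embedding)[0]
    return out.cpu()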
import os
import sys

# Make modules in the parent directory importable.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
def get_optimizer(encoder, decoder):
    # A single Adam optimizer over the parameters of both networks.
    optim = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)
    return optim
ssim_loss = pytorch_ssim.SSIM()  # returns SSIM similarity (1 = identical images), not a loss
l1_loss = nn.L1Loss()
l2_loss = nn.MSELoss()
SSIM_LOSS_RATIO = 0.1

def compute_ssim_loss(output, target):
    # SSIM lies in [-1, 1] and higher means more similar, so map it to a loss in [0, 1]
    # that is 0 for identical images.
    similarity = pytorch_ssim.ssim(output, target)
    return (1 - similarity) / 2

def l2andssim_loss(output, target):
    # Weighted combination of the pixel-wise L2 loss and the structural SSIM loss.
    return (1 - SSIM_LOSS_RATIO) * l2_loss(output, target) + SSIM_LOSS_RATIO * compute_ssim_loss(output, target)
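To make clear what the SSIM term measures, here is a simplified NumPy sketch. It computes SSIM over the whole image as a single window, using the standard constants C1 = (0.01 L)^2 and C2 = (0.03 L)^2 with data range L = 1. This is an assumption made for illustration: pytorch_ssim averages SSIM over sliding 11 x 11 Gaussian-weighted windows instead. The key point carries over to the real loss. SSIM compares means, variances, and covariance, equals 1 for identical images, and must therefore be turned into a loss, for example as 1 - SSIM, before it is minimized.

```python
import numpy as np

def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM over the whole image (simplified for illustration)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    # Luminance/contrast/structure terms combined into the usual SSIM ratio.
    return ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )

def ssim_as_loss(x, y):
    # SSIM is a similarity (1 = identical), so minimize 1 - SSIM.
    return 1.0 - global_ssim(x, y)

rng = np.random.default_rng(0)
img = rng.random((64, 64))
# Crude blur: average each pixel with its upper and left neighbours (wrapping at the border).
blurred = (img + np.roll(img, 1, axis=0) + np.roll(img, 1, axis=1)) / 3
```

Blurring lowers both the variance and the correlation with the original, so `ssim_as_loss(img, blurred)` is larger than `ssim_as_loss(img, img)`, which is 0.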
def get_infilled_image(img, encoder, decoder):
    """Return a copy of the masked image with the predicted center pasted in."""
    fill_area = predict_fill_area_for_img(img, encoder, decoder)
    with_center = AddImageCenter(IMAGE_SIZE, CENTER_SIZE)(torch.clone(img), fill_area)
    return with_center
def show_sample_photos(encoder, decoder, title="Sample Batch Photos"):
    # Show (inpainted image, ground truth) pairs for the first three sample images.
    images = []
    for sample_image_index in range(0, 3):
        cutout_image = sample_cutout[sample_image_index]
        predicted_image = get_infilled_image(cutout_image, encoder, decoder)
        images += [predicted_image, sample_gt[sample_image_index]]
    display_images(images, images_per_row=2, title=title)
def get_val_loss(encoder, decoder, loss_fn):
    """Average loss per validation batch."""
    # Evaluation mode, so BatchNorm uses (and does not update) its running statistics.
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        total_loss = 0
        number_of_batches = len(val_loader)
        for batch in val_loader:
            images, gt_images, _ = batch
            images = images.to(device)
            gt_images = gt_images.to(device)
            embedded_images = encoder(images)
            decoded_images = decoder(embedded_images)
            total_loss += loss_fn(get_center_of_image(gt_images), decoded_images).item()
    return total_loss / number_of_batches
def display_loss_history_graph(train_loss_history, val_loss_history, title="Training and Validation Loss"):
    # Each history entry is [epoch, loss].
    train_loss_history = np.array(train_loss_history)
    val_loss_history = np.array(val_loss_history)
    plt.plot(train_loss_history[:, 0], train_loss_history[:, 1], label="Training Loss")
    plt.plot(val_loss_history[:, 0], val_loss_history[:, 1], label="Validation Loss")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
def do_training(epochs, train_loss_history, val_loss_history, data_loader, optim, encoder, decoder, loss_fn, val_interval=10, should_save=False, title="Training and Validation Loss"):
    best_val = float("inf")
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    if should_save:
        os.makedirs("./models", exist_ok=True)
    for epoch in range(epochs):
        encoder.train()
        decoder.train()
        epoch_loss = 0
        number_of_batches = len(data_loader)
        pbar = tqdm(enumerate(data_loader), total=number_of_batches)
        for i, batch in pbar:
            images, gt_images, _ = batch
            images = images.to(device)
            gt_images = gt_images.to(device)
            optim.zero_grad()
            embedded_images = encoder(images)
            decoded_images = decoder(embedded_images)
            # The decoder only predicts the missing center, so compare against the center crop.
            loss = loss_fn(get_center_of_image(gt_images), decoded_images)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
        epoch_loss /= number_of_batches
        train_loss_history.append([epoch, epoch_loss])
        if epoch % val_interval == 0:
            val_loss = get_val_loss(encoder, decoder, loss_fn)
            print(f"Val Loss: {val_loss}")
            val_loss_history.append([epoch, val_loss])
            # Checkpoint whenever the validation loss improves.
            if val_loss < best_val and should_save:
                save_model(encoder, decoder, epoch, optim, f"./models/model-{epoch}-{val_loss:.2f}.pth")
                best_val = val_loss
        if epoch % 5 == 0:
            clear_output()
            display_loss_history_graph(train_loss_history, val_loss_history, title=title)
            plt.show()
        print(f"Epoch Loss: {epoch_loss}")
sanity_enc = Encoder()
sanity_dec = Decoder()
sanity_optim = get_optimizer(sanity_enc, sanity_dec)
sanity_train_loss_history = []
sanity_val_loss_history = []
# do_training(
# epochs=100,
# train_loss_history=sanity_train_loss_history,
# val_loss_history=sanity_val_loss_history,
# data_loader=sanity_loader,
# optim=sanity_optim,
# encoder=sanity_enc,
# decoder=sanity_dec,
# loss_fn=l1_loss
# )
sanity_img, sanity_gt, sanity_class = next(iter(sanity_loader))
display_images([sanity_img[0], get_infilled_image(sanity_img[0], sanity_enc, sanity_dec), sanity_gt[0]], images_per_row=3)
load_model(sanity_enc, sanity_dec, sanity_optim, "./models/model-3150-1.85.pth")
display_loss_history_graph(sanity_train_loss_history, sanity_val_loss_history, title="Sanity Training and Validation Loss")
enc = Encoder()
dec = Decoder()
optim = get_optimizer(enc, dec)
# enc = nn.DataParallel(enc)
# dec = nn.DataParallel(dec)
train_loss_history = []
val_loss_history = []
do_training(
epochs=5000,
train_loss_history=train_loss_history,
val_loss_history=val_loss_history,
data_loader=train_loader,
optim=optim,
encoder=enc,
decoder=dec,
should_save=True,
loss_fn=l1_loss
)
display_loss_history_graph(train_loss_history, val_loss_history, title="Training and Validation Loss")
class SobelFilter(nn.Module):
    """Fixed (non-trainable) Sobel filters. Each filter sums the gradients of the three
    RGB channels, so both outputs have a single channel."""
    def __init__(self):
        super(SobelFilter, self).__init__()
        self.sobel_filter_x = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.sobel_filter_y = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Weight shape is (out_channels=1, in_channels=3, 3, 3).
        sobel_kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3).repeat(1, 3, 1, 1)
        sobel_kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3).repeat(1, 3, 1, 1)
        # Freeze the kernels; gradients still flow through to the input image.
        self.sobel_filter_x.weight = nn.Parameter(sobel_kernel_x, requires_grad=False)
        self.sobel_filter_y.weight = nn.Parameter(sobel_kernel_y, requires_grad=False)

    def forward(self, image):
        grad_output_x = self.sobel_filter_x(image)
        grad_output_y = self.sobel_filter_y(image)
        return grad_output_x, grad_output_y
sobel_filter = SobelFilter()
sobel_filter = sobel_filter.to(device)
sample_batch_img, sample_batch_img_gt, sample_batch_classes = next(iter(val_loader))
sample_train_batch_img, sample_train_batch_img_gt, sample_traing_batch_classes = next(iter(train_loader))
images = []
sobel_filter_sample = sobel_filter(sample_batch_img_gt.to(device))
for i in range(4, 8):
    # x-gradient, y-gradient, and the original image side by side
    images += [
        sobel_filter_sample[0].cpu()[i],
        sobel_filter_sample[1].cpu()[i],
        sample_batch_img_gt[i]
    ]
display_images(images, images_per_row=3, title="Sobel Filter Results")
# Compute a loss based on the difference in edges (Sobel gradients) between two images
def edge_diff_loss(img1, img2):
    grad_x_img_1, grad_y_img_1 = sobel_filter(img1)
    grad_x_img_2, grad_y_img_2 = sobel_filter(img2)
    grad_diff_x = torch.abs(grad_x_img_1 - grad_x_img_2)
    grad_diff_y = torch.abs(grad_y_img_1 - grad_y_img_2)
    loss = torch.mean(grad_diff_x + grad_diff_y)
    return loss
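The following NumPy sketch shows why an edge-difference loss penalizes blur. It applies the same Sobel kernels to a sharp vertical step edge and to a blurred version of it, using a hand-written "valid" 3 x 3 cross-correlation, which is what `nn.Conv2d` computes. The helper names are hypothetical, and unlike `SobelFilter` it works on a single-channel image without padding. The blurred edge has a lower peak gradient, so its gradients differ from the sharp target's and the loss is nonzero.

```python
import numpy as np

SOBEL_X = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)
SOBEL_Y = SOBEL_X.T  # [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], the notebook's y kernel

def correlate3x3(img, kernel):
    # 'Valid' 3x3 cross-correlation (no padding, no kernel flip), as nn.Conv2d computes.
    h, w = img.shape
    out = np.zeros((h - 2, w - 2))
    for i in range(3):
        for j in range(3):
            out += kernel[i, j] * img[i:i + h - 2, j:j + w - 2]
    return out

def edge_diff(a, b):
    # Mean absolute difference of x- and y-gradients, like edge_diff_loss.
    gx = np.abs(correlate3x3(a, SOBEL_X) - correlate3x3(b, SOBEL_X))
    gy = np.abs(correlate3x3(a, SOBEL_Y) - correlate3x3(b, SOBEL_Y))
    return np.mean(gx + gy)

sharp = np.zeros((8, 8))
sharp[:, 4:] = 1.0  # sharp vertical step edge
soft = np.zeros((8, 8))
soft[:, 3], soft[:, 4], soft[:, 5:] = 1 / 3, 2 / 3, 1.0  # the same edge spread over 3 columns
```

The peak x-gradient is 4 for the sharp edge but only 8/3 for the soft one.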
def l1_and_sharpness_loss(output, target):
    # Alternative weighting that was also tried:
    # L1_LAMBDA = 10
    # SHARPNESS_LAMBDA = 0.5
    # SSIM_LAMBDA = 0.8
    L1_LAMBDA = 10
    SHARPNESS_LAMBDA = 0
    SSIM_LAMBDA = 1
    l1_loss_value = l1_loss(output, target)
    edge_diff_loss_value = edge_diff_loss(output, target)
    # ssim_loss returns a similarity (1 = identical), so use 1 - SSIM to penalize dissimilarity.
    ssim_loss_value = 1 - ssim_loss(output, target)
    return L1_LAMBDA * l1_loss_value + SHARPNESS_LAMBDA * edge_diff_loss_value + SSIM_LAMBDA * ssim_loss_value
sharpness_enc = Encoder()
sharpness_dec = Decoder()
optim = get_optimizer(sharpness_enc, sharpness_dec)
sharpness_train_loss_history = []
sharpness_val_loss_history = []
do_training(
epochs=10000,
train_loss_history=sharpness_train_loss_history,
val_loss_history=sharpness_val_loss_history,
data_loader=train_loader,
optim=optim,
encoder=sharpness_enc,
decoder=sharpness_dec,
should_save=True,
loss_fn=l1_and_sharpness_loss
)
# sanity_img, sanity_gt, sanity_class = next(iter(train_loader))
display_images([sample_train_batch_img[0], get_infilled_image(sample_train_batch_img[0], sharpness_enc, sharpness_dec), sample_train_batch_img_gt[0]], images_per_row=3)
# load_model(sharpness_enc, sharpness_dec, optim, "./models/SSIMOnlymodel-730-1.40.pth")
# load_model(sharpness_enc, sharpness_dec, optim, "./models/only-sharpness-model-130-1.42.pth")
# load_model(sharpness_enc, sharpness_dec, optim, "./models/l1-sharpness-model-710-1.72.pth")
# load_model(sharpness_enc, sharpness_dec, optim, "./models/only-sharpness-model-1850-1.58.pth")
# load_model(sharpness_enc, sharpness_dec, optim, "./models/only-sharpness-model-2010-1.58.pth")
# load_model(sharpness_enc, sharpness_dec, optim, "./models/l2-sharpness-model-3360-0.90.pth")
load_model(sharpness_enc, sharpness_dec, optim, "./models/high-l2-sharpness-model-9770-17.86.pth")
# model-610-1.42.pth
# indexes = [20, 35, 34, 23, 24]
images = []
# To show only selected examples, loop over `indexes` instead of range(0, 35).
for index in range(0, 35):
    # Columns: masked input, inpainted image, ground truth
    images += [sample_train_batch_img[index], get_infilled_image(sample_train_batch_img[index], sharpness_enc, sharpness_dec), sample_train_batch_img_gt[index]]
    # To show validation images instead:
    # images += [sample_batch_img[index], get_infilled_image(sample_batch_img[index], sharpness_enc, sharpness_dec), sample_batch_img_gt[index]]
display_images(images, images_per_row=3, title="High SSIM Loss Training Images")
def get_pca(data):
    # Project the embeddings (N x D tensor) onto their first two principal components.
    data = data.numpy()
    pca = PCA(n_components=2)
    data_2d = pca.fit_transform(data)
    return data_2d
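A 2-D PCA projection of a 4000-dimensional (or 2048-dimensional, for ResNet) embedding may keep only a small share of the variance. It can therefore help to report how much variance the two plotted components explain. The NumPy sketch below computes the same projection that scikit-learn's `PCA` produces, up to the sign of each component, via an SVD of the mean-centered data. It also returns the fraction of variance per component, which scikit-learn exposes as `explained_variance_ratio_`. The function name `pca_2d` and the synthetic embeddings are hypothetical.

```python
import numpy as np

def pca_2d(data):
    """Project rows of `data` (N x D) onto the top-2 principal components via SVD."""
    centered = data - data.mean(axis=0)               # PCA requires mean-centered data
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:2].T                   # N x 2 coordinates
    explained_var = s ** 2 / (len(data) - 1)          # variance along each component
    return projected, explained_var[:2] / explained_var.sum()

rng = np.random.default_rng(0)
# Synthetic embeddings: 200 points in 50-D with one dominant direction of variation.
emb = rng.normal(size=(200, 50))
emb[:, 0] *= 10
coords, ratio = pca_2d(emb)
```

With real embeddings, a small `ratio.sum()` would suggest treating the scatter plots as a rough qualitative view of the latent space.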
def visulize_encoder_embeddings(images, classes, encoder, by_class=False):
    with torch.no_grad():
        img_batch = images.to(device)
        embedded = encoder(img_batch).detach().cpu()
    embeddings_pca = get_pca(embedded)
    classes = np.asarray(classes)
    plt.figure(figsize=(10, 8))
    if by_class:
        # One randomly colored, labeled point cloud per class.
        for idx, name in SELECTED_CLASSES.items():
            r, g, b = torch.rand(3).tolist()
            plt.scatter(embeddings_pca[classes == idx, 0], embeddings_pca[classes == idx, 1], color=[(r, g, b)], label=name)
        plt.legend()  # only the per-class plot has labeled points
    else:
        # Color each point by the average RGB color of its (masked) input image.
        avg_color = torch.mean(images, dim=(2, 3)).tolist()
        plt.scatter(embeddings_pca[:, 0], embeddings_pca[:, 1], color=avg_color)
    by_txt = "Image Class" if by_class else "Average Color"
    plt.title("PCA of Embeddings Visualized By " + by_txt)
    plt.xlabel("Principal Component 1")
    plt.ylabel("Principal Component 2")
    plt.grid(True)
    plt.show()
visulize_encoder_embeddings(sample_batch_img, sample_batch_classes, enc, by_class=True)
visulize_encoder_embeddings(sample_batch_img, sample_batch_classes, enc)
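class ResNetEncoder(nn.Module):
    """Pretrained ResNet-50 feature extractor producing a 2048-dimensional embedding.

    Note: the pretrained weights expect ImageNet-normalized inputs; the images here are only scaled to [0, 1].
    """
    def __init__(self):
        super(ResNetEncoder, self).__init__()
        self.resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')

    def forward(self, inp):
        # All ResNet-50 layers except the final fully connected classifier.
        inp = self.resnet.conv1(inp)
        inp = self.resnet.bn1(inp)
        inp = self.resnet.relu(inp)
        inp = self.resnet.maxpool(inp)
        inp = self.resnet.layer1(inp)
        inp = self.resnet.layer2(inp)
        inp = self.resnet.layer3(inp)
        inp = self.resnet.layer4(inp)
        inp = self.resnet.avgpool(inp)
        inp = inp.reshape((-1, 2048))
        return inp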
res_net_encoder = ResNetEncoder()
res_net_encoder = res_net_encoder.to(device).eval()  # eval mode: BatchNorm uses its pretrained running statistics
visulize_encoder_embeddings(sample_batch_img, sample_batch_classes, res_net_encoder, by_class=True)
visulize_encoder_embeddings(sample_batch_img, sample_batch_classes, res_net_encoder)
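def unsharp_mask(img):
    # Unsharp masking: add back a scaled copy of the image's high-frequency detail.
    blurred = torchvision.transforms.GaussianBlur(3, sigma=(7, 7))(img)
    unsharp = img - 1.0 * blurred  # high-frequency detail
    sharp = img + 10 * unsharp     # sharpening amount of 10
    return sharp.clamp(0, 1)       # keep values in the valid [0, 1] range for display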
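A large sharpening amount overshoots on both sides of an edge, which pushes pixel values outside [0, 1]. The 1-D NumPy sketch below shows this with the same amount of 10 used in `unsharp_mask`. It substitutes a 3-tap box blur for the Gaussian blur, so the helper names and the blur are assumptions for illustration only.

```python
import numpy as np

def box_blur_1d(signal):
    # 3-tap box blur with edge replication (a stand-in for the Gaussian blur).
    padded = np.pad(signal, 1, mode="edge")
    return (padded[:-2] + padded[1:-1] + padded[2:]) / 3

def unsharp_1d(signal, amount=10.0, clamp=True):
    detail = signal - box_blur_1d(signal)  # high-frequency component
    sharp = signal + amount * detail       # boost the edges
    return np.clip(sharp, 0.0, 1.0) if clamp else sharp

edge = np.array([0, 0, 0.25, 0.75, 1, 1], dtype=float)  # a soft edge
raw = unsharp_1d(edge, clamp=False)
clamped = unsharp_1d(edge)
```

Without clamping, `raw` dips below 0 on the dark side of the edge and rises above 1 on the bright side. These out-of-range values are what trigger matplotlib's clipping warnings when such images are shown.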
def get_infilled_image_unsharp_mask(img, encoder, decoder):
    """Return (plain inpainting, unsharp-masked inpainting, blurred unsharp-masked inpainting)."""
    fill_area = predict_fill_area_for_img(img, encoder, decoder)
    sharp_fill_area = unsharp_mask(fill_area)
    sharp_with_center = AddImageCenter(IMAGE_SIZE, CENTER_SIZE)(torch.clone(img), sharp_fill_area)
    with_center = AddImageCenter(IMAGE_SIZE, CENTER_SIZE)(torch.clone(img), fill_area)
    # Gaussian blur (kernel 5, sigma 5) of the whole sharpened image, for comparison.
    blurred = torchvision.transforms.GaussianBlur(5, sigma=(5, 5))(sharp_with_center)
    return with_center, sharp_with_center, blurred
def visualize_unsharp_mask(encoder, decoder):
    """Show predicted center patches next to their unsharp-masked versions."""
    images = []
    for i in range(5, 10):
        fill_area = predict_fill_area_for_img(sample_train_batch_img[i], encoder, decoder)
        unsharped = unsharp_mask(fill_area)
        images += [fill_area, unsharped]
    display_images(images)
visualize_unsharp_mask(sharpness_enc, sharpness_dec)
images = []
for index in range(0, 10):
    # img, img_gt = sample_train_batch_img[index], sample_train_batch_img_gt[index]
    img, img_gt = sample_batch_img[index], sample_batch_img_gt[index]
    # Columns: masked input, plain inpainting, sharpened inpainting, blurred sharpened inpainting, ground truth
    plain, sharpened, blurred = get_infilled_image_unsharp_mask(img, sharpness_enc, sharpness_dec)
    images += [img, plain, sharpened, blurred, img_gt]
display_images(images, images_per_row=5, title="Unsharp Mask Images")